//Decoder.scala
package mycore

import chisel3._
import chisel3.util._

/** Combinational RV64I instruction decoder, also covering the CSR
  * instructions, FENCE.I and a few privileged instructions (MRET, SRET, WFI,
  * SFENCE.VMA).
  *
  * Each `io.isa` flag is asserted when `io.inst` matches that instruction's
  * encoding. `io.isa.ILL` is asserted when no encoding listed here matches.
  * Encodings of extensions this decoder does not list (e.g. M: MUL/DIV/REM)
  * therefore decode as illegal.
  *
  * `io.imm` exposes every immediate format extracted from `io.inst`, extended
  * to 64 bits; downstream logic selects the one it needs.
  *
  * `io.csr` is high for any of the six CSR instructions. `io.wen` is high for
  * instructions that write a result to `rd` (no special case for rd = x0 is
  * made here).
  */
class Decoder extends Module {
  val io = IO(new Bundle {
    val inst = Input(UInt(32.W))
    val isa  = Output(new ISA)
    val imm  = Output(new IMM)
    val csr  = Output(Bool())
    val wen  = Output(Bool())
  })

  // ---------------------------------------------------------------------------
  // Instruction formats: I
  // ---------------------------------------------------------------------------

  // RV64 shift-immediates: funct6 | shamt[5:0] | rs1 | funct3 | rd | opcode
  io.isa.SLLI  := (io.inst === BitPat("b000000_??????_?????_001_?????_0010011"))
  io.isa.SRLI  := (io.inst === BitPat("b000000_??????_?????_101_?????_0010011"))
  io.isa.SRAI  := (io.inst === BitPat("b010000_??????_?????_101_?????_0010011"))
  // *W shift-immediates use a full 7-bit funct7 (shamt is only 5 bits); the
  // encodings with bit 25 set are reserved and must decode as illegal.
  io.isa.SLLIW := (io.inst === BitPat("b0000000_?????_?????_001_?????_0011011"))
  io.isa.SRLIW := (io.inst === BitPat("b0000000_?????_?????_101_?????_0011011"))
  io.isa.SRAIW := (io.inst === BitPat("b0100000_?????_?????_101_?????_0011011"))

  // imm[11:0] | rs1 | funct3 | rd | opcode
  io.isa.ADDI  := (io.inst === BitPat("b????????????_?????_000_?????_0010011"))
  io.isa.ADDIW := (io.inst === BitPat("b????????????_?????_000_?????_0011011"))

  io.isa.XORI  := (io.inst === BitPat("b????????????_?????_100_?????_0010011"))
  io.isa.ORI   := (io.inst === BitPat("b????????????_?????_110_?????_0010011"))
  io.isa.ANDI  := (io.inst === BitPat("b????????????_?????_111_?????_0010011"))

  io.isa.SLTI  := (io.inst === BitPat("b????????????_?????_010_?????_0010011"))
  io.isa.SLTIU := (io.inst === BitPat("b????????????_?????_011_?????_0010011"))

  io.isa.JALR  := (io.inst === BitPat("b????????????_?????_000_?????_1100111"))

  // The spec requires base implementations to ignore the fm/rs1/rd fields of
  // FENCE (reserved fm values, e.g. FENCE.TSO, behave as a normal FENCE) and
  // the imm/rs1/rd fields of FENCE.I, so those bits are don't-cares here.
  io.isa.FENCE   := (io.inst === BitPat("b????????????_?????_000_?????_0001111"))
  io.isa.FENCE_I := (io.inst === BitPat("b????????????_?????_001_?????_0001111"))

  io.isa.ECALL  := (io.inst === BitPat("b000000000000_00000_000_00000_1110011"))
  io.isa.EBREAK := (io.inst === BitPat("b000000000001_00000_000_00000_1110011"))

  // csr | rs1/uimm | funct3 | rd | opcode
  io.isa.CSRRW  := (io.inst === BitPat("b????????????_?????_001_?????_1110011"))
  io.isa.CSRRWI := (io.inst === BitPat("b????????????_?????_101_?????_1110011"))
  io.isa.CSRRS  := (io.inst === BitPat("b????????????_?????_010_?????_1110011"))
  io.isa.CSRRSI := (io.inst === BitPat("b????????????_?????_110_?????_1110011"))
  io.isa.CSRRC  := (io.inst === BitPat("b????????????_?????_011_?????_1110011"))
  io.isa.CSRRCI := (io.inst === BitPat("b????????????_?????_111_?????_1110011"))

  io.isa.LD  := (io.inst === BitPat("b????????????_?????_011_?????_0000011"))
  io.isa.LW  := (io.inst === BitPat("b????????????_?????_010_?????_0000011"))
  io.isa.LWU := (io.inst === BitPat("b????????????_?????_110_?????_0000011"))
  io.isa.LH  := (io.inst === BitPat("b????????????_?????_001_?????_0000011"))
  io.isa.LHU := (io.inst === BitPat("b????????????_?????_101_?????_0000011"))
  io.isa.LB  := (io.inst === BitPat("b????????????_?????_000_?????_0000011"))
  io.isa.LBU := (io.inst === BitPat("b????????????_?????_100_?????_0000011"))

  // ---------------------------------------------------------------------------
  // Instruction formats: R
  // funct7 | rs2 | rs1 | funct3 | rd | opcode
  // All 7 funct7 bits are matched: funct7 = 0000001 is the M extension
  // (MUL/DIV/REM...), which this decoder does not implement and which must
  // therefore raise ILL rather than alias onto ADD/SLL/XOR/etc.
  // ---------------------------------------------------------------------------
  io.isa.SLL  := (io.inst === BitPat("b0000000_?????_?????_001_?????_0110011"))
  io.isa.SLLW := (io.inst === BitPat("b0000000_?????_?????_001_?????_0111011"))
  io.isa.SRL  := (io.inst === BitPat("b0000000_?????_?????_101_?????_0110011"))
  io.isa.SRLW := (io.inst === BitPat("b0000000_?????_?????_101_?????_0111011"))
  io.isa.SRA  := (io.inst === BitPat("b0100000_?????_?????_101_?????_0110011"))
  io.isa.SRAW := (io.inst === BitPat("b0100000_?????_?????_101_?????_0111011"))

  io.isa.ADD  := (io.inst === BitPat("b0000000_?????_?????_000_?????_0110011"))
  io.isa.ADDW := (io.inst === BitPat("b0000000_?????_?????_000_?????_0111011"))
  io.isa.SUB  := (io.inst === BitPat("b0100000_?????_?????_000_?????_0110011"))
  io.isa.SUBW := (io.inst === BitPat("b0100000_?????_?????_000_?????_0111011"))

  io.isa.XOR  := (io.inst === BitPat("b0000000_?????_?????_100_?????_0110011"))
  io.isa.OR   := (io.inst === BitPat("b0000000_?????_?????_110_?????_0110011"))
  io.isa.AND  := (io.inst === BitPat("b0000000_?????_?????_111_?????_0110011"))

  io.isa.SLT  := (io.inst === BitPat("b0000000_?????_?????_010_?????_0110011"))
  io.isa.SLTU := (io.inst === BitPat("b0000000_?????_?????_011_?????_0110011"))

  io.isa.MRET       := (io.inst === BitPat("b0011000_00010_00000_000_00000_1110011"))
  io.isa.SRET       := (io.inst === BitPat("b0001000_00010_00000_000_00000_1110011"))
  io.isa.WFI        := (io.inst === BitPat("b0001000_00101_00000_000_00000_1110011"))
  io.isa.SFENCE_VMA := (io.inst === BitPat("b0001001_?????_?????_000_00000_1110011"))

  // ---------------------------------------------------------------------------
  // Instruction formats: B/S
  // imm | rs2 | rs1 | funct3 | imm | opcode
  // ---------------------------------------------------------------------------
  io.isa.BEQ  := (io.inst === BitPat("b???????_?????_?????_000_?????_1100011"))
  io.isa.BNE  := (io.inst === BitPat("b???????_?????_?????_001_?????_1100011"))
  io.isa.BLT  := (io.inst === BitPat("b???????_?????_?????_100_?????_1100011"))
  io.isa.BGE  := (io.inst === BitPat("b???????_?????_?????_101_?????_1100011"))
  io.isa.BLTU := (io.inst === BitPat("b???????_?????_?????_110_?????_1100011"))
  io.isa.BGEU := (io.inst === BitPat("b???????_?????_?????_111_?????_1100011"))
  io.isa.SD   := (io.inst === BitPat("b???????_?????_?????_011_?????_0100011"))
  io.isa.SW   := (io.inst === BitPat("b???????_?????_?????_010_?????_0100011"))
  io.isa.SH   := (io.inst === BitPat("b???????_?????_?????_001_?????_0100011"))
  io.isa.SB   := (io.inst === BitPat("b???????_?????_?????_000_?????_0100011"))

  // ---------------------------------------------------------------------------
  // Instruction formats: J/U
  // imm (20 bits) | rd (5 bits) | opcode (7 bits) = 32 bits
  // ---------------------------------------------------------------------------
  io.isa.LUI   := (io.inst === BitPat("b????????????????????_?????_0110111"))
  io.isa.AUIPC := (io.inst === BitPat("b????????????????????_?????_0010111"))
  io.isa.JAL   := (io.inst === BitPat("b????????????????????_?????_1101111"))

  // ---------------------------------------------------------------------------
  // Instruction classes
  // ---------------------------------------------------------------------------
  val Arithmetic  = io.isa.ADD || io.isa.ADDW || io.isa.ADDI || io.isa.ADDIW ||
                    io.isa.SUB || io.isa.SUBW || io.isa.LUI || io.isa.AUIPC
  val Logical     = io.isa.XOR || io.isa.XORI || io.isa.OR || io.isa.ORI ||
                    io.isa.AND || io.isa.ANDI
  val Shifts      = io.isa.SLL || io.isa.SLLI || io.isa.SLLW || io.isa.SLLIW ||
                    io.isa.SRL || io.isa.SRLI || io.isa.SRLW || io.isa.SRLIW ||
                    io.isa.SRA || io.isa.SRAI || io.isa.SRAW || io.isa.SRAIW
  val Compare     = io.isa.SLT || io.isa.SLTI || io.isa.SLTU || io.isa.SLTIU
  val Branches    = io.isa.BEQ || io.isa.BNE || io.isa.BLT || io.isa.BGE ||
                    io.isa.BLTU || io.isa.BGEU
  val Jump_Link   = io.isa.JAL || io.isa.JALR
  val Sync        = io.isa.FENCE || io.isa.FENCE_I
  val Environment = io.isa.ECALL || io.isa.EBREAK
  val CSR         = io.isa.CSRRW || io.isa.CSRRS || io.isa.CSRRC ||
                    io.isa.CSRRWI || io.isa.CSRRSI || io.isa.CSRRCI
  val Loads       = io.isa.LD || io.isa.LW || io.isa.LH || io.isa.LB ||
                    io.isa.LWU || io.isa.LHU || io.isa.LBU
  val Stores      = io.isa.SD || io.isa.SW || io.isa.SH || io.isa.SB
  val Privileged  = io.isa.MRET || io.isa.SRET || io.isa.WFI || io.isa.SFENCE_VMA
  val Legal       = Arithmetic || Logical || Shifts || Compare || Branches ||
                    Jump_Link || Sync || Environment || CSR || Loads || Stores ||
                    Privileged

  // Anything that matched none of the encodings above is an illegal instruction.
  io.isa.ILL := !Legal

  // ---------------------------------------------------------------------------
  // Immediates (all formats are produced; consumers select the one they need)
  // ---------------------------------------------------------------------------
  val I = io.inst(31, 20)
  val B = Cat(io.inst(31), io.inst(7), io.inst(30, 25), io.inst(11, 8), 0.U(1.W))
  val S = Cat(io.inst(31, 25), io.inst(11, 7))
  val U = Cat(io.inst(31, 12), 0.U(12.W))
  val J = Cat(io.inst(31), io.inst(19, 12), io.inst(20), io.inst(30, 21), 0.U(1.W))
  val Z = io.inst(19, 15) // 5-bit unsigned immediate of CSRR*I (the rs1 field)
  io.imm.I := SignExt(I, 64)
  io.imm.B := SignExt(B, 64)
  io.imm.S := SignExt(S, 64)
  io.imm.U := SignExt(U, 64)
  io.imm.J := SignExt(J, 64)
  io.imm.Z := ZeroExt(Z, 64)

  io.csr := CSR
  // Instructions that produce a value for rd (branches, stores, fences,
  // environment and privileged instructions do not).
  io.wen := Arithmetic || Logical || Shifts || Compare || Jump_Link || Loads || CSR

}